{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "fd5dcb5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "24613d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "gru1 = nn.GRU(2, 4, 1, batch_first=True)\n",
    "gru_test_input = torch.zeros([1,3,2])+0.5\n",
    "gru1_output = gru1(gru_test_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "81ac42fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[-0.3352, -0.0475, -0.0508,  0.0806],\n",
       "          [-0.3978, -0.0550, -0.0398,  0.0720],\n",
       "          [-0.4103, -0.0571, -0.0257,  0.0596]]], grad_fn=<TransposeBackward1>),\n",
       " tensor([[[-0.4103, -0.0571, -0.0257,  0.0596]]], grad_fn=<StackBackward>))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gru1_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "5be1c088",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gru1.bias_ih_l0.data.fill_(0)\n",
    "gru1.bias_hh_l0.data.fill_(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "5156b364",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 3, 4])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gru1_output = gru1(gru_test_input)\n",
    "gru1_output[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "31dfa209",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y\n"
     ]
    }
   ],
   "source": [
    "if isinstance(gru1,nn.GRU):\n",
    "    print('y')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "037748c9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 2])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gru1.weight_ih_l0.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c1652b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_gru_input_kernel(kernel):\n",
    "    kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3)\n",
    "    kernels = [kernel_z, kernel_r, kernel_h]\n",
    "    return torch.tensor(np.hstack([k.reshape(k.T.shape) for k in kernels]))\n",
    "\n",
    "def convert_gru_recurrent_kernel(kernel):\n",
    "    kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3)\n",
    "    kernels = [kernel_z, kernel_r, kernel_h]\n",
    "    return torch.tensor(np.hstack(kernels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2c5bbe5f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 12])\n"
     ]
    }
   ],
   "source": [
    "new_weight_ih_l0 = convert_gru_input_kernel(gru1.weight_ih_l0.detach().numpy())\n",
    "print(new_weight_ih_l0.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6bb3c098",
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "import tensorflow\n",
    "from keras import layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a310a67f",
   "metadata": {},
   "outputs": [],
   "source": [
    "keras_conv1 = layers.Conv1D(3, 2, padding='same', activation='tanh', name='feature_conv1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9afb23f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_shape = (1,3,4)\n",
    "x = tensorflow.random.normal(input_shape)\n",
    "y = keras_conv1(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "72c5296e",
   "metadata": {},
   "outputs": [],
   "source": [
    "weights =  keras_conv1.get_weights()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "18315a2e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 4, 3)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weights[0].shape\n",
    "#input 1\n",
    "#kernal_size 0\n",
    "#output 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "757e6b12",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3,)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weights[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6cd8bedc",
   "metadata": {},
   "outputs": [],
   "source": [
    "keras_gru1 = layers.GRU(4,return_sequences=True, recurrent_activation=\"sigmoid\", reset_after='true', name='gru_a',)\n",
    "input_shape = (1,3,2)\n",
    "x = tensorflow.random.normal(input_shape)\n",
    "y = keras_gru1(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c3b4b238",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 12)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weights =  keras_gru1.get_weights()\n",
    "weights[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c5ece6c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.50630426 -0.18125176 -0.22455446 -0.22690475 -0.21649115 -0.19073492\n",
      "  -0.01370192 -0.47307855  0.05561786  0.21302928  0.0155199   0.50292975]\n",
      " [-0.05305615  0.01354556  0.02428324 -0.10544971  0.5824456   0.23123857\n",
      "   0.18460734 -0.3195278  -0.5989489  -0.13466732 -0.20463274  0.19403724]\n",
      " [-0.03756262  0.24331398 -0.7099939  -0.00343641  0.21649931 -0.02412327\n",
      "  -0.04613936 -0.11119071 -0.07995344  0.18405285  0.4935046  -0.2992046 ]\n",
      " [ 0.30607358  0.029637   -0.01524288 -0.06086059  0.20136751 -0.4949669\n",
      "  -0.19576669 -0.22203046  0.10759834 -0.52205724 -0.24321698 -0.4301926 ]]\n"
     ]
    }
   ],
   "source": [
    "print(weights[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b44c780b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([1, 3, 4])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
